{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a3b44703",
   "metadata": {},
   "source": [
    "# PyTorch NN with SparkXShards for tabular data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "416b0a4b",
   "metadata": {},
   "source": [
    "Copyright 2016 The BigDL Authors."
   ]
  },
  {
   "cell_type": "raw",
   "id": "a474a628",
   "metadata": {},
   "source": [
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf304c91",
   "metadata": {},
   "source": [
    "SparkXShards in Orca lets users process large-scale datasets with existing Python code in a distributed, data-parallel fashion. This notebook shows how to train a PyTorch model on SparkXShards data with Orca.\n",
    "\n",
    "It is adapted from [[TPS-5] Pytorch NN for tabular - step by step](https://www.kaggle.com/code/remekkinas/tps-5-pytorch-nn-for-tabular-step-by-step/notebook)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5405faa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use seaborn for plots\n",
    "!pip install seaborn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "355b11ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import necessary libraries\n",
    "import numpy as np  # used below to convert features to float32 arrays\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "from bigdl.orca import init_orca_context, stop_orca_context\n",
    "import bigdl.orca.data.pandas\n",
    "from bigdl.orca.data.transformer import *\n",
    "from bigdl.orca.learn.pytorch import Estimator\n",
    "from bigdl.orca.learn.metrics import Accuracy\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e90082ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# start an OrcaContext\n",
    "sc = init_orca_context(cores=4, memory=\"4g\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a46ad6a7",
   "metadata": {},
   "source": [
    "## Load data in parallel and get general information"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "904060e4",
   "metadata": {},
   "source": [
    "Load the data into `data_shards`, a SparkXShards that can be operated on in parallel. Each element of `data_shards` is a pandas DataFrame read from a file on the cluster. Users can distribute the local call `pd.read_csv(dataFile)` by replacing it with `bigdl.orca.data.pandas.read_csv(datapath)`."
   ]
  },
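  {
   "cell_type": "markdown",
   "id": "5c1e0a01",
   "metadata": {},
   "source": [
    "As a rough mental model (an illustration only, not a description of how Orca is implemented), a SparkXShards can be thought of as a collection of pandas DataFrames, each holding part of the rows. The sketch below mimics that locally by reading a small in-memory CSV in chunks; the sample data and the name `local_shards` are made up for this example.\n",
    "\n",
    "```python\n",
    "import io\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical sample data standing in for train.csv\n",
    "csv_text = '''id,feature_0,target\n",
    "0,1,Class_2\n",
    "1,0,Class_1\n",
    "2,3,Class_1\n",
    "3,0,Class_4\n",
    "'''\n",
    "\n",
    "# Each chunk is an ordinary pandas DataFrame holding part of the rows\n",
    "local_shards = list(pd.read_csv(io.StringIO(csv_text), chunksize=2))\n",
    "\n",
    "print(len(local_shards))                     # number of chunks\n",
    "print(sum(len(df) for df in local_shards))   # total number of rows\n",
    "```\n",
    "\n",
    "Code written for a single DataFrame can then be applied to every chunk, which is the idea behind `transform_shard` used later in this notebook."
   ]
  },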
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7141367",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_shards = bigdl.orca.data.pandas.read_csv('../train.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6d12d912",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>feature_0</th>\n",
       "      <th>feature_1</th>\n",
       "      <th>feature_2</th>\n",
       "      <th>feature_3</th>\n",
       "      <th>feature_4</th>\n",
       "      <th>feature_5</th>\n",
       "      <th>feature_6</th>\n",
       "      <th>feature_7</th>\n",
       "      <th>feature_8</th>\n",
       "      <th>...</th>\n",
       "      <th>feature_41</th>\n",
       "      <th>feature_42</th>\n",
       "      <th>feature_43</th>\n",
       "      <th>feature_44</th>\n",
       "      <th>feature_45</th>\n",
       "      <th>feature_46</th>\n",
       "      <th>feature_47</th>\n",
       "      <th>feature_48</th>\n",
       "      <th>feature_49</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>21</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>Class_2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>Class_1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>Class_1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>Class_4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>Class_2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 52 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   id  feature_0  feature_1  feature_2  feature_3  feature_4  feature_5  \\\n",
       "0   0          0          0          1          0          1          0   \n",
       "1   1          0          0          0          0          2          1   \n",
       "2   2          0          0          0          0          0          0   \n",
       "3   3          0          0          0          0          0          0   \n",
       "4   4          0          0          0          0          0          0   \n",
       "\n",
       "   feature_6  feature_7  feature_8  ...  feature_41  feature_42  feature_43  \\\n",
       "0          0          0          0  ...           0           0          21   \n",
       "1          0          0          0  ...           0           0           0   \n",
       "2          0          0          0  ...           0           1           0   \n",
       "3          0          3          0  ...           0           0           0   \n",
       "4          0          0          0  ...           0           0           0   \n",
       "\n",
       "   feature_44  feature_45  feature_46  feature_47  feature_48  feature_49  \\\n",
       "0           0           0           0           0           0           0   \n",
       "1           0           0           0           0           0           0   \n",
       "2           0           0           0          13           2           0   \n",
       "3           0           0           0           0           1           0   \n",
       "4           0           0           0           0           1           0   \n",
       "\n",
       "    target  \n",
       "0  Class_2  \n",
       "1  Class_1  \n",
       "2  Class_1  \n",
       "3  Class_4  \n",
       "4  Class_2  \n",
       "\n",
       "[5 rows x 52 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show the first couple of rows in the data_shards\n",
    "data_shards.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3b8526d9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100000"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# count total number of rows in the data_shards\n",
    "len(data_shards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "de53c793",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['id', 'feature_0', 'feature_1', 'feature_2', 'feature_3', 'feature_4', 'feature_5', 'feature_6', 'feature_7', 'feature_8', 'feature_9', 'feature_10', 'feature_11', 'feature_12', 'feature_13', 'feature_14', 'feature_15', 'feature_16', 'feature_17', 'feature_18', 'feature_19', 'feature_20', 'feature_21', 'feature_22', 'feature_23', 'feature_24', 'feature_25', 'feature_26', 'feature_27', 'feature_28', 'feature_29', 'feature_30', 'feature_31', 'feature_32', 'feature_33', 'feature_34', 'feature_35', 'feature_36', 'feature_37', 'feature_38', 'feature_39', 'feature_40', 'feature_41', 'feature_42', 'feature_43', 'feature_44', 'feature_45', 'feature_46', 'feature_47', 'feature_48', 'feature_49', 'target']\n"
     ]
    }
   ],
   "source": [
    "# columns information of element of data_shards.\n",
    "columns = data_shards.get_schema()['columns']\n",
    "print(columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce183207",
   "metadata": {},
   "source": [
    "## Assemble features and labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f062a1c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# drop duplicates in the data\n",
    "data_shards = data_shards.deduplicates()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef28781d",
   "metadata": {},
   "source": [
    "The labels are strings. Users can transform them into integers using `StringIndexer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f61203e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def change_col_name(df):\n",
    "    df = df.rename(columns={'id': 'id0'})\n",
    "    return df\n",
    "data_shards = data_shards.transform_shard(change_col_name)\n",
    "\n",
    "\n",
    "encode = StringIndexer(inputCol='target')\n",
    "data_shards = encode.fit_transform(data_shards)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6caf0c40",
   "metadata": {},
   "source": [
    "The indexed labels start from 1, so they need to be shifted to be zero-based."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6208ecf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_label_to_zero_base(df):\n",
    "    df['target'] = df['target'] - 1\n",
    "    return df\n",
    "data_shards = data_shards.transform_shard(update_label_to_zero_base)"
   ]
  },
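  {
   "cell_type": "markdown",
   "id": "5c1e0a02",
   "metadata": {},
   "source": [
    "The shift matters because `nn.CrossEntropyLoss`, used later, expects class indices in the range `[0, num_classes - 1]`. A minimal pandas-only sketch of the two steps on one local DataFrame follows. The mapping built here (sorted label order, 1-based) is an assumption made for illustration and may not match the order that `StringIndexer` assigns.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Toy labels, made up for this example\n",
    "df = pd.DataFrame({'target': ['Class_2', 'Class_1', 'Class_1', 'Class_4']})\n",
    "\n",
    "# Hypothetical 1-based mapping from label string to integer index\n",
    "labels = sorted(df['target'].unique())\n",
    "mapping = {label: i + 1 for i, label in enumerate(labels)}\n",
    "df['target'] = df['target'].map(mapping)\n",
    "\n",
    "# Shift to zero-based indices\n",
    "df['target'] = df['target'] - 1\n",
    "print(df['target'].tolist())\n",
    "```"
   ]
  },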
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff7e890e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_train_test(data):\n",
    "    RANDOM_STATE = 2021\n",
    "    train, test = train_test_split(data, test_size=0.2, random_state=RANDOM_STATE)\n",
    "    return train, test\n",
    "train_shards, val_shards = data_shards.transform_shard(split_train_test).split()"
   ]
  },
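  {
   "cell_type": "markdown",
   "id": "5c1e0a03",
   "metadata": {},
   "source": [
    "`transform_shard` applies the function to each shard. Because `split_train_test` returns a tuple, every shard becomes a `(train, test)` pair, and `split()` then separates those pairs into two SparkXShards. The sketch below mimics this with a plain Python list of DataFrames; `local_shards` and the toy data are assumptions for illustration.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Two toy shards of 10 rows each\n",
    "local_shards = [\n",
    "    pd.DataFrame({'x': range(10)}),\n",
    "    pd.DataFrame({'x': range(10, 20)}),\n",
    "]\n",
    "\n",
    "# Split every shard on its own, then regroup the train and validation parts\n",
    "pairs = [train_test_split(df, test_size=0.2, random_state=2021) for df in local_shards]\n",
    "train_parts, val_parts = zip(*pairs)\n",
    "\n",
    "print([len(df) for df in train_parts], [len(df) for df in val_parts])\n",
    "```\n",
    "\n",
    "Since the split happens inside each shard, the 80/20 ratio holds per shard rather than being drawn from the dataset as a whole."
   ]
  },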
  {
   "cell_type": "markdown",
   "id": "d94e15e0",
   "metadata": {},
   "source": [
    "Users can call `train_shards.select('target')` to select the 'target' column, then call `sample_to_pdf` to bring the data from across the cluster into a local pandas DataFrame, and use seaborn to check how the classes are distributed in the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e8ac8915",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     target\n",
      "160       1\n",
      "375       0\n",
      "446       1\n",
      "305       1\n",
      "137       1\n",
      "223       1\n",
      "323       1\n",
      "23        1\n",
      "316       0\n",
      "424       0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:xlabel='target', ylabel='count'>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAEGCAYAAACkQqisAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAARfUlEQVR4nO3df6xkZX3H8feHXRCtIig3FHepS3Rjs1JF2SKVRFuIulAV4q9AVLaWum1Eo0ljxaap9QeNplbqDzQhZRWMEVFsQYIhG0SJRoRFENillCtqWYLuyk/RoF389o95FqfrXbw83Jm5s/t+JZN7zvc8Z+Y7EzYfzpwzz0lVIUlSj70m3YAkaXoZIpKkboaIJKmbISJJ6maISJK6LZ10A+N24IEH1ooVKybdhiRNjWuvvfanVTUz17Y9LkRWrFjBxo0bJ92GJE2NJD/a1Ta/zpIkdTNEJEndDBFJUjdDRJLUzRCRJHUzRCRJ3QwRSVI3Q0SS1M0QkSR12+N+sa7x+p/3/dGkW1g0/uAfb5x0C9KC80hEktTNEJEkdTNEJEndDBFJUjdDRJLUzRCRJHUzRCRJ3QwRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSN0NEktTNEJEkdTNEJEndDBFJUjdDRJLUzRCRJHUzRCRJ3QwRSVI3Q0SS1M0QkSR1M0QkSd1GHiJJliS5Lsklbf3QJN9JMpvkC0n2afXHtfXZtn3F0HO8u9VvSfKyofqaVptNcvqo34sk6f8bx5HI24Gbh9Y/BJxZVc8E7gFObfVTgXta/cw2jiSrgJOAZwNrgE+2YFoCnAUcB6wCTm5jJUljMtIQSbIc+HPg39t6gGOAL7Uh5wIntuUT2jpt+7Ft/AnA+VX1y6r6ATALHNkes1V1W1X9Cji/jZUkjcmoj0T+Dfg74Ndt/anAvVW1va1vAZa15WXA7QBt+31t/MP1nfbZVf23JFmXZGOSjdu2bXuMb0mStMPIQiTJy4GtVXXtqF5jvqrq7KpaXVWrZ2ZmJt2OJO02lo7wuY8GXpnkeGBfYD/go8D+SZa2o43lwB1t/B3AIcCWJEuBJwN3DdV3GN5nV3VJ0hiM7Eikqt5dVcuragWDE+Nfq6rXA1cAr2nD1gIXteWL2zpt+9eqqlr9pHb11qHASuBq4BpgZbvaa5/2GheP6v1Ikn7bKI9EduVdwPlJPgBcB5zT6ucAn00yC9zNIBSoqk1JLgA2A9uB06rqIYAkbwUuA5YA66tq01jfiSTt4cYSIlX1deDrbfk2BldW7TzmQeC1u9j/DOCMOeqXApcuYKuSpEfBX6xLkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqZohIkroZIpKkboaIJKmbISJJ6maISJK6GSKSpG6GiCSpmyEiSepmiEiSuhkikqRuhogkqZshIknqNrIQSbJvkquTfC/JpiTvbfVDk3wnyWySLyTZp9Uf19Zn2/YVQ8/17la/JcnLhuprWm02yemjei+SpLmN8kjkl8AxVfVc4HBgTZKjgA8BZ1bVM4F7gFPb+FOBe1r9zDaOJKuAk4BnA2uATyZZkmQJcBZwHLAKOLmNlSSNychCpAYeaKt7t0cBxwBfavVzgRPb8gltnbb92CRp9fOr6pdV9QNgFjiyPWar6raq+hVwfhsrSRqTkZ4TaUcM1wNbgQ3A94F7q2p7G7IFWNaWlwG3A7Tt9wFPHa7vtM+u6nP1sS7JxiQbt23btgDvTJIEIw6Rqnqoqg4HljM4cvjDUb7eI/RxdlWtrqrVMzMzk2hBknZLY7k6q6ruBa4A/gTYP8nStmk5cEdbvgM4BKBtfzJw13B9p312VZckjckor86aSbJ/W3488BLgZgZh8po2
bC1wUVu+uK3Ttn+tqqrVT2pXbx0KrASuBq4BVrarvfZhcPL94lG9H0nSb1v6u4d0Oxg4t11FtRdwQVVdkmQzcH6SDwDXAee08ecAn00yC9zNIBSoqk1JLgA2A9uB06rqIYAkbwUuA5YA66tq0wjfjyRpJyMLkaq6AXjeHPXbGJwf2bn+IPDaXTzXGcAZc9QvBS59zM1Kkrr4i3VJUjdDRJLUzRCRJHUzRCRJ3QwRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSt3mFSJLL51OTJO1ZHnHakyT7Ak8ADkxyAJC2aT92ce8OSdKe43fNnfXXwDuApwHX8psQuR/4xOjakiRNg0cMkar6KPDRJG+rqo+PqSdJ0pSY1yy+VfXxJC8EVgzvU1XnjagvSdIUmFeIJPks8AzgeuChVi7AEJGkPdh87yeyGljV7jQoSRIw/9+J3AT8/igbkSRNn/keiRwIbE5yNfDLHcWqeuVIupIkTYX5hsg/jbIJSdJ0mu/VWd8YdSOSpOkz36uzfsbgaiyAfYC9gZ9X1X6jakyStPjN90jkSTuWkwQ4AThqVE1JkqbDo57Ftwb+E3jZwrcjSZom8/0661VDq3sx+N3IgyPpSJI0NeZ7ddYrhpa3Az9k8JWWJGkPNt9zIm8adSOSpOkz35tSLU/yH0m2tseFSZaPujlJ0uI23xPrnwYuZnBfkacBX2k1SdIebL4hMlNVn66q7e3xGWBmhH1JkqbAfEPkriRvSLKkPd4A3DXKxiRJi998Q+QvgdcBPwbuBF4D/MWIepIkTYn5XuL7PmBtVd0DkOQpwIcZhIskaQ813yOR5+wIEICquht43mhakiRNi/mGyF5JDtix0o5E5nsUI0naTc03CP4V+HaSL7b11wJnjKYlSdK0mO8v1s9LshE4ppVeVVWbR9eWJGkazHsW36raXFWfaI/fGSBJDklyRZLNSTYleXurPyXJhiS3tr8HtHqSfCzJbJIbkjx/6LnWtvG3Jlk7VD8iyY1tn4+1aeolSWPyqKeCfxS2A39bVasY3HvktCSrgNOBy6tqJXB5Wwc4DljZHuuAT8HD51/eA7wAOBJ4z9D5mU8Bbx7ab80I348kaScjC5GqurOqvtuWfwbcDCxjMPvvuW3YucCJbfkE4Lx2v5KrgP2THMzgviUbqurudoXYBmBN27ZfVV1VVQWcN/RckqQxGOWRyMOSrGBwSfB3gIOq6s626cfAQW15GXD70G5bWu2R6lvmqM/1+uuSbEyycdu2bY/tzUiSHjbyEEnyROBC4B1Vdf/wtnYEUXPuuICq6uyqWl1Vq2dmnPJLkhbKSEMkyd4MAuRzVfXlVv5J+yqK9ndrq98BHDK0+/JWe6T68jnqkqQxGVmItCulzgFurqqPDG26GNhxhdVa4KKh+intKq2jgPva116XAS9NckA7of5S4LK27f4kR7XXOmXouSRJYzDKX50fDbwRuDHJ9a3298AHgQuSnAr8iMHEjgCXAscDs8AvgDfBYIqVJO8Hrmnj3temXQF4C/AZ4PHAV9tDkjQmIwuRqvomsKvfbRw7x/gCTtvFc60H1s9R3wgc9hjalCQ9BmO5OkuStHsyRCRJ3QwRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSN29xu5Mj3nnepFtYNK79l1Mm3YKkRc4jEUlSN0NEktTNEJEkdTNEJEndDBFJUjdDRJLUzRCRJHUzRCRJ3QwRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSN0NEktTNEJEkdTNEJEndDBFJUjdDRJLUzRCRJHUzRCRJ3QwRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSN0NEktTNEJEkdTNEJEndRhYiSdYn2ZrkpqHaU5JsSHJr+3tAqyfJx5LMJrkhyfOH9lnbxt+aZO1Q/YgkN7Z9PpYko3ovkqS5jfJI5DPAmp1qpwOXV9VK4PK2DnAcsLI91gGfgkHoAO8BXgAcCbxnR/C0MW8e2m/n15IkjdjSUT1xVV2ZZMVO
5ROAP23L5wJfB97V6udVVQFXJdk/ycFt7IaquhsgyQZgTZKvA/tV1VWtfh5wIvDVUb0fadKO/vjRk25h0fjW27416RbUjPucyEFVdWdb/jFwUFteBtw+NG5Lqz1Sfcsc9TklWZdkY5KN27Zte2zvQJL0sImdWG9HHTWm1zq7qlZX1eqZmZlxvKQk7RHGHSI/aV9T0f5ubfU7gEOGxi1vtUeqL5+jLkkao3GHyMXAjius1gIXDdVPaVdpHQXc1772ugx4aZID2gn1lwKXtW33JzmqXZV1ytBzSZLGZGQn1pN8nsGJ8QOTbGFwldUHgQuSnAr8CHhdG34pcDwwC/wCeBNAVd2d5P3ANW3c+3acZAfewuAKsMczOKHuSXVJGrNRXp118i42HTvH2AJO28XzrAfWz1HfCBz2WHqUJD02/mJdktTNEJEkdTNEJEndDBFJUjdDRJLUzRCRJHUzRCRJ3QwRSVK3kf3YUJIWs2+86MWTbmHRePGV3+je1yMRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSN0NEktTNEJEkdTNEJEndDBFJUjdDRJLUzRCRJHUzRCRJ3QwRSVI3Q0SS1M0QkSR1M0QkSd0MEUlSN0NEktTNEJEkdTNEJEndDBFJUjdDRJLUzRCRJHUzRCRJ3QwRSVI3Q0SS1M0QkSR1m/oQSbImyS1JZpOcPul+JGlPMtUhkmQJcBZwHLAKODnJqsl2JUl7jqkOEeBIYLaqbquqXwHnAydMuCdJ2mOkqibdQ7ckrwHWVNVftfU3Ai+oqrfuNG4dsK6tPgu4ZayNPnoHAj+ddBO7ET/PheXnubCm4fN8elXNzLVh6bg7mYSqOhs4e9J9zFeSjVW1etJ97C78PBeWn+fCmvbPc9q/zroDOGRofXmrSZLGYNpD5BpgZZJDk+wDnARcPOGeJGmPMdVfZ1XV9iRvBS4DlgDrq2rThNtaCFPz1duU8PNcWH6eC2uqP8+pPrEuSZqsaf86S5I0QYaIJKmbIbLIOI3LwkmyPsnWJDdNupdpl+SQJFck2ZxkU5K3T7qnaZZk3yRXJ/le+zzfO+meenlOZBFp07j8N/ASYAuDq89OrqrNE21sSiV5EfAAcF5VHTbpfqZZkoOBg6vqu0meBFwLnOh/m32SBPi9qnogyd7AN4G3V9VVE27tUfNIZHFxGpcFVFVXAndPuo/dQVXdWVXfbcs/A24Glk22q+lVAw+01b3bYyr/j94QWVyWAbcPrW/Bf6haZJKsAJ4HfGfCrUy1JEuSXA9sBTZU1VR+noaIpHlL8kTgQuAdVXX/pPuZZlX1UFUdzmCmjSOTTOVXrobI4uI0Llq02nf3FwKfq6ovT7qf3UVV3QtcAayZcCtdDJHFxWlctCi1E8HnADdX1Ucm3c+0SzKTZP+2/HgGF9P810Sb6mSILCJVtR3YMY3LzcAFu8k0LhOR5PPAt4FnJdmS5NRJ9zTFjgbeCByT5Pr2OH7STU2xg4ErktzA4H8eN1TVJRPuqYuX+EqSunkkIknqZohIkroZIpKkboaIJKmbISJJ6maISAsoyf5J3jKG1zkxyapRv470uxgi0sLaH5h3iGSg59/hiYAhoonzdyLSAkqyY+blWxhMZfEc4AAGs7T+Q1Vd1CYwvIzBBIZHAMcDpwBvALYxmITz2qr6cJJnAGcBM8AvgDcDTwEuAe5rj1dX1ffH9R6lYUsn3YC0mzkdOKyqDk+yFHhCVd2f5EDgqiQ7prFZCaytqquS/DHwauC5DMLmuwzu1wFwNvA3VXVrkhcAn6yqY9rzXFJVXxrnm5N2ZohIoxPgn9vNsX7NYFr/g9q2Hw3dgOho4KKqehB4MMlX4OEZc18IfHEwdRUAjxtX89J8GCLS6LyewddQR1TV/yb5IbBv2/bzeey/F3Bvmy5cWpQ8sS4trJ8BT2rLTwa2tgD5M+Dpu9jnW8Ar2n23nwi8HKDdr+MHSV4LD5+Ef+4cryNNjCEiLaCqugv4VpKbgMOB
1UluZHDifM6pvqvqGgZT/t8AfBW4kcEJcxgczZya5HvAJn5zu+TzgXcmua6dfJcmwquzpEUgyROr6oEkTwCuBNbtuKe5tJh5TkRaHM5uPx7cFzjXANG08EhEktTNcyKSpG6GiCSpmyEiSepmiEiSuhkikqRu/wc5Xa805Nl9GwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = train_shards.select(\"target\").sample_to_pdf(frac=1.0)\n",
    "print(y.head(10))\n",
    "sns.countplot(x='target', data=y)"
   ]
  },
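  {
   "cell_type": "markdown",
   "id": "d453c904",
   "metadata": {},
   "source": [
    "Normalize the features with `MinMaxScaler`: fit the scaler on the training shards, then apply it to the validation shards."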
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abeb51a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_list = ['feature_' + str(i) for i in range(50)]\n",
    "scale = MinMaxScaler(inputCol=feature_list, outputCol=\"x_scaled\")\n",
    "train_shards = scale.fit_transform(train_shards)\n",
    "val_shards = scale.transform(val_shards)"
   ]
  },
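  {
   "cell_type": "markdown",
   "id": "5c1e0a04",
   "metadata": {},
   "source": [
    "In its usual form, min-max scaling maps each feature to `(x - min) / (max - min)`. The statistics are taken from the training data only and reused for the validation data. A NumPy sketch with made-up numbers (not the notebook's data) shows the arithmetic:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Toy data: rows are samples, columns are features\n",
    "train = np.array([[0.0, 10.0], [2.0, 30.0], [4.0, 20.0]])\n",
    "val = np.array([[1.0, 40.0]])\n",
    "\n",
    "# Fit: per-feature minimum and maximum of the training data\n",
    "col_min = train.min(axis=0)\n",
    "col_max = train.max(axis=0)\n",
    "\n",
    "# Transform: the same statistics are applied to both sets\n",
    "train_scaled = (train - col_min) / (col_max - col_min)\n",
    "val_scaled = (val - col_min) / (col_max - col_min)\n",
    "\n",
    "print(train_scaled)\n",
    "print(val_scaled)  # unseen values can fall outside [0, 1]\n",
    "```"
   ]
  },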
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f9494cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Change data types\n",
    "def change_data_type(df):\n",
    "    df['x_scaled'] = df['x_scaled'].apply(lambda x: np.array(x, dtype=np.float32))\n",
    "    df['target'] = df['target'].apply(lambda x: int(x))\n",
    "    return df\n",
    "\n",
    "train_shards = train_shards.transform_shard(change_data_type)\n",
    "val_shards = val_shards.transform_shard(change_data_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5cf45f4",
   "metadata": {},
   "source": [
    "## Define PyTorch model and train it"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6318a0a",
   "metadata": {},
   "source": [
    "Users can build a PyTorch model as usual and use the Orca Estimator to train it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d6e4a182",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# define a PyTorch model\n",
    "torch.manual_seed(0)\n",
    "BATCH_SIZE = 64\n",
    "NUM_CLASSES = 4\n",
    "NUM_EPOCHS = 1\n",
    "NUM_FEATURES = 50\n",
    "\n",
    "\n",
    "def linear_block(in_features, out_features, p_drop, *args, **kwargs):\n",
    "    return nn.Sequential(\n",
    "        nn.Linear(in_features, out_features),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(p=p_drop)\n",
    "    )\n",
    "\n",
    "\n",
    "class TPS05ClassificationSeq(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(TPS05ClassificationSeq, self).__init__()\n",
    "        num_feature = NUM_FEATURES\n",
    "        num_class = NUM_CLASSES\n",
    "        self.linear = nn.Sequential(\n",
    "            linear_block(num_feature, 100, 0.3),\n",
    "            linear_block(100, 250, 0.3),\n",
    "            linear_block(250, 128, 0.3),\n",
    "        )\n",
    "\n",
    "        self.out = nn.Sequential(\n",
    "            nn.Linear(128, num_class)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.linear(x)\n",
    "        return self.out(x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "eb845acd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define criterion and optimizer\n",
    "def optim_creator(model, config):\n",
    "    return optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
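  {
   "cell_type": "markdown",
   "id": "5c1e0a05",
   "metadata": {},
   "source": [
    "`nn.CrossEntropyLoss` takes raw logits of shape `(batch, num_classes)` and integer class indices of shape `(batch,)`, which is why the model's last layer has no softmax and why `target` was converted to zero-based integers. A small standalone check with random tensors follows; the batch size of 8 is arbitrary and the 4 classes match `NUM_CLASSES`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "logits = torch.randn(8, 4)             # stand-in for model outputs\n",
    "targets = torch.randint(0, 4, (8,))    # zero-based class indices (int64)\n",
    "\n",
    "# The loss is averaged over the batch and returned as a scalar tensor\n",
    "loss = nn.CrossEntropyLoss()(logits, targets)\n",
    "print(loss.item())\n",
    "```"
   ]
  },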
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a37426d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# build Orca Estimator taking a model_creator\n",
    "def model_creator(config):\n",
    "    model = TPS05ClassificationSeq()\n",
    "    print(model)\n",
    "    return model\n",
    "\n",
    "est = Estimator.from_torch(model=model_creator, optimizer=optim_creator,\n",
    "                           loss=criterion, metrics=[Accuracy()], backend=\"ray\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c84d0a81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train the model\n",
    "est.fit(data=train_shards, feature_cols=['x_scaled'], label_cols=['target'],\n",
    "        validation_data=val_shards, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b85c1e89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate the model\n",
    "result = est.evaluate(data=val_shards, feature_cols=['x_scaled'],\n",
    "                      label_cols=['target'], batch_size=BATCH_SIZE)\n",
    "\n",
    "for r in result:\n",
    "    print(r, \":\", result[r])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b979da67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# stop OrcaContext\n",
    "stop_orca_context()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37tf2_x3",
   "language": "python",
   "name": "py37tf2_x3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
